Skip to content

Commit 5f5ba73

Browse files
author
Devin Lu
committed
refactor(eda): ran just ci
1 parent c801377 commit 5f5ba73

File tree

4 files changed

+55
-31
lines changed

4 files changed

+55
-31
lines changed

dataprep/clean/clean_country.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def _format_country(
271271
return result, 2 if val != result else 3
272272

273273

274-
@lru_cache(maxsize=2**20)
274+
@lru_cache(maxsize=2 ** 20)
275275
def _check_country(country: str, input_formats: Tuple[str, ...], strict: bool, clean: bool) -> Any:
276276
"""
277277
Finds the index of the given country in the DATA dataframe.
@@ -322,7 +322,7 @@ def _check_country(country: str, input_formats: Tuple[str, ...], strict: bool, c
322322
return (None, "unknown") if clean else False
323323

324324

325-
@lru_cache(maxsize=2**20)
325+
@lru_cache(maxsize=2 ** 20)
326326
def _check_fuzzy_dist(country: str, fuzzy_dist: int) -> Any:
327327
"""
328328
A match is found if a country has an edit distance <= fuzzy_dist

dataprep/eda/create_diff_report/diff_formatter.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def format_diff_report(
6868
cfg: Config,
6969
mode: Optional[str],
7070
progress: bool = True,
71-
target: Optional[str] = None
71+
target: Optional[str] = None,
7272
) -> Dict[str, Any]:
7373
"""
7474
Format the data and figures needed by create_diff_report
@@ -118,6 +118,7 @@ def format_diff_report(
118118
raise ValueError(f"Unknown mode: {mode}")
119119
return report
120120

121+
121122
def validate_target(target: str, df_list: List[pd.DataFrame]):
122123
"""
123124
Helper function, verify that target column exists
@@ -128,7 +129,8 @@ def validate_target(target: str, df_list: List[pd.DataFrame]):
128129
exists = True
129130
break
130131
if not exists:
131-
raise ValueError(f'Sorry, {target} is not a valid column')
132+
raise ValueError(f"Sorry, {target} is not a valid column")
133+
132134

133135
def format_basic(df_list: List[pd.DataFrame], target: Optional[str], cfg: Config) -> Dict[str, Any]:
134136
"""
@@ -295,7 +297,9 @@ def compute_plot_data(
295297
elif is_dtype(dtp, DateTime_v1()):
296298
plot_data.append((col, dtp, dask.compute(*datum), orig)) # workaround
297299

298-
return Intermediate(data=plot_data, stats=stats, visual_type="comparison_grid", target=target, df_list=pd_list)
300+
return Intermediate(
301+
data=plot_data, stats=stats, visual_type="comparison_grid", target=target, df_list=pd_list
302+
)
299303

300304

301305
def _compute_variables(df: EDAFrame, cfg: Config) -> Dict[str, Any]:

dataprep/eda/diff/render.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def bar_viz(
8484
df_labels: List[str],
8585
baseline: int,
8686
target: Optional[str] = None,
87-
df_list: Optional[List[pd.DataFrame]] = None
87+
df_list: Optional[List[pd.DataFrame]] = None,
8888
) -> Figure:
8989
"""
9090
Render a bar chart
@@ -119,7 +119,7 @@ def bar_viz(
119119
tools="hover",
120120
x_range=list(df[baseline].index),
121121
y_axis_type=yscale,
122-
y_range=(min(col1_min, col2_min) * (1 - y_inc), max(col1_max, col2_max) * (1 + y_inc))
122+
y_range=(min(col1_min, col2_min) * (1 - y_inc), max(col1_max, col2_max) * (1 + y_inc)),
123123
)
124124
row_names = None
125125
offset = np.linspace(-0.08 * len(df), 0.08 * len(df), len(df)) if len(df) > 1 else [0]
@@ -157,7 +157,7 @@ def bar_viz(
157157

158158
if show_yticks and yscale == "linear":
159159
_format_axis(fig, 0, df[baseline].max(), "y")
160-
160+
161161
df1, df2 = df_list[0], df_list[1]
162162
if target != col and target and col in df1.columns and col in df2.columns:
163163
col1, col2 = df_list[0][col], df_list[1][col]
@@ -166,12 +166,23 @@ def bar_viz(
166166
for names in row_names:
167167
row_avgs_1.append(df_list[0][target][col1 == names].mean())
168168
row_avgs_2.append(df_list[1][target][col2 == names].mean())
169-
169+
170170
row_avgs_1 = [0 if math.isnan(x) else x for x in row_avgs_1]
171171
row_avgs_2 = [0 if math.isnan(x) else x for x in row_avgs_2]
172-
fig.extra_y_ranges = {"Averages": Range1d(start=min(row_avgs_1 + row_avgs_2) * (1 - y_inc), end=max(row_avgs_1 + row_avgs_2) * (1 + y_inc))}
173-
fig.multi_line([row_names, row_names], [row_avgs_1, row_avgs_2], color=['navy', 'firebrick'], y_range_name="Averages", line_width=4)
174-
fig.add_layout(LinearAxis(y_range_name="Averages"), 'right')
172+
fig.extra_y_ranges = {
173+
"Averages": Range1d(
174+
start=min(row_avgs_1 + row_avgs_2) * (1 - y_inc),
175+
end=max(row_avgs_1 + row_avgs_2) * (1 + y_inc),
176+
)
177+
}
178+
fig.multi_line(
179+
[row_names, row_names],
180+
[row_avgs_1, row_avgs_2],
181+
color=["navy", "firebrick"],
182+
y_range_name="Averages",
183+
line_width=4,
184+
)
185+
fig.add_layout(LinearAxis(y_range_name="Averages"), "right")
175186
return fig
176187

177188

@@ -186,7 +197,7 @@ def hist_viz(
186197
df_labels: List[str],
187198
orig: Optional[List[str]] = None,
188199
target: Optional[str] = None,
189-
df_list: Optional[List[pd.DataFrame]] = None
200+
df_list: Optional[List[pd.DataFrame]] = None,
190201
) -> Figure:
191202
"""
192203
Render a histogram
@@ -222,14 +233,13 @@ def hist_viz(
222233
counts_max_2 = max(counts_list[1])
223234

224235
y_start, y_end = min(counts_min_1, counts_min_2), max(counts_max_1, counts_max_2)
225-
226236

227237
fig = Figure(
228238
plot_height=plot_height,
229239
plot_width=plot_width,
230240
title=col,
231241
toolbar_location=None,
232-
y_axis_type=yscale
242+
y_axis_type=yscale,
233243
)
234244
bins_list = []
235245
for i, hst in enumerate(hist):
@@ -252,7 +262,9 @@ def hist_viz(
252262
bottom = 0 if yscale == "linear" or df.empty else counts.min() / 2
253263
if y_start is not None and y_end is not None:
254264
# fig.y_range = (y_start * (1 - y_inc), y_end * (1 + y_inc))
255-
fig.extra_y_ranges = {"Counts": Range1d(start=y_start * (1 - y_inc), end=y_end * (1 + y_inc))}
265+
fig.extra_y_ranges = {
266+
"Counts": Range1d(start=y_start * (1 - y_inc), end=y_end * (1 + y_inc))
267+
}
256268
fig.quad(
257269
source=df,
258270
left="left",
@@ -262,7 +274,7 @@ def hist_viz(
262274
top="freq",
263275
fill_color=CATEGORY10[i],
264276
line_color=CATEGORY10[i],
265-
y_range_name="Counts"
277+
y_range_name="Counts",
266278
)
267279
else:
268280
fig.quad(
@@ -273,11 +285,11 @@ def hist_viz(
273285
alpha=0.5,
274286
top="freq",
275287
fill_color=CATEGORY10[i],
276-
line_color=CATEGORY10[i]
288+
line_color=CATEGORY10[i],
277289
)
278290
# if col == 'LotFrontage':
279-
# breakpoint()
280-
291+
# breakpoint()
292+
281293
hover = HoverTool(tooltips=tooltips, attachment="vertical", mode="vline")
282294
fig.add_tools(hover)
283295

@@ -325,9 +337,17 @@ def hist_viz(
325337
max_range = max(df1_bin_averages + df2_bin_averages)
326338
min_range = min(df1_bin_averages + df2_bin_averages)
327339

328-
fig.extra_y_ranges['Averages'] = Range1d(start=min_range * (1 - y_inc), end=max_range * (1 + y_inc))
329-
fig.multi_line([bins_1, bins_2], [df1_bin_averages, df2_bin_averages], color=['navy', 'firebrick'], y_range_name="Averages", line_width=4)
330-
fig.add_layout(LinearAxis(y_range_name="Averages", axis_label='Bin Averages'), 'right')
340+
fig.extra_y_ranges["Averages"] = Range1d(
341+
start=min_range * (1 - y_inc), end=max_range * (1 + y_inc)
342+
)
343+
fig.multi_line(
344+
[bins_1, bins_2],
345+
[df1_bin_averages, df2_bin_averages],
346+
color=["navy", "firebrick"],
347+
y_range_name="Averages",
348+
line_width=4,
349+
)
350+
fig.add_layout(LinearAxis(y_range_name="Averages", axis_label="Bin Averages"), "right")
331351
return fig
332352

333353

@@ -678,7 +698,7 @@ def format_num_stats(data: Dict[str, List[Any]]) -> Dict[str, Dict[str, List[Any
678698
descriptive = {
679699
"Mean": data["mean"],
680700
"Standard Deviation": data["std"],
681-
"Variance": [std**2 for std in data["std"]],
701+
"Variance": [std ** 2 for std in data["std"]],
682702
"Sum": [mean * npres for mean, npres in zip(data["mean"], data["npres"])],
683703
"Skewness": [float(skew) for skew in data["skew"]],
684704
"Kurtosis": [float(kurt) for kurt in data["kurt"]],
@@ -734,7 +754,7 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
734754
df_labels,
735755
baseline if len(df) > 1 else 0,
736756
target,
737-
df_list
757+
df_list,
738758
)
739759
elif is_dtype(dtp, Continuous()):
740760
if cfg.diff.density:
@@ -753,7 +773,7 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
753773
df_labels,
754774
orig,
755775
target,
756-
df_list
776+
df_list,
757777
)
758778
elif is_dtype(dtp, DateTime()):
759779
df, timeunit = data

dataprep/eda/intermediate.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
3030
visual_type = kwargs.pop("visual_type")
3131
super().__init__(**kwargs)
3232
self.visual_type = visual_type
33-
if 'target' in kwargs:
34-
self.target = kwargs.pop('target')
35-
36-
if 'df_list' in kwargs:
37-
self.df_list = kwargs.pop('df_list')
33+
if "target" in kwargs:
34+
self.target = kwargs.pop("target")
35+
36+
if "df_list" in kwargs:
37+
self.df_list = kwargs.pop("df_list")
3838
else:
3939
raise ValueError("Unsupported initialization")
4040

0 commit comments

Comments
 (0)