1
1
"""
2
2
This module implements the visualization for the plot_diff function.
3
3
""" # pylint: disable=too-many-lines
4
+ from turtle import color
4
5
from typing import Any , Dict , List , Tuple , Optional
5
-
6
+ from sklearn . preprocessing import MinMaxScaler
6
7
import math
7
8
import numpy as np
8
9
import pandas as pd
10
+ import dask .array as da
11
+ import matplotlib .pyplot as plt
9
12
from bokeh .models import (
10
13
HoverTool ,
11
14
Panel ,
12
15
FactorRange ,
13
16
)
14
- from bokeh .plotting import Figure , figure
17
+ from bokeh .plotting import Figure , figure , show
15
18
from bokeh .transform import dodge
16
19
from bokeh .layouts import row
20
+ from bokeh .models .ranges import Range1d
21
+ from bokeh .models import LinearAxis
17
22
18
23
from ..configs import Config
19
24
from ..dtypes import Continuous , DateTime , Nominal , is_dtype
@@ -78,6 +83,8 @@ def bar_viz(
78
83
orig : List [str ],
79
84
df_labels : List [str ],
80
85
baseline : int ,
86
+ target : Optional [str ] = None ,
87
+ df_list : Optional [List [pd .DataFrame ]] = None
81
88
) -> Figure :
82
89
"""
83
90
Render a bar chart
@@ -94,6 +101,12 @@ def bar_viz(
94
101
("Source" , "@orig" ),
95
102
]
96
103
104
+ col1_min = df [0 ][col ].min ()
105
+ col2_min = df [1 ][col ].min ()
106
+ col1_max = df [0 ][col ].max ()
107
+ col2_max = df [1 ][col ].max ()
108
+ y_inc = 0.05
109
+
97
110
if show_yticks :
98
111
if len (df [baseline ]) > 10 :
99
112
plot_width = 28 * len (df [baseline ])
@@ -106,12 +119,15 @@ def bar_viz(
106
119
tools = "hover" ,
107
120
x_range = list (df [baseline ].index ),
108
121
y_axis_type = yscale ,
122
+ y_range = (min (col1_min , col2_min ) * (1 - y_inc ), max (col1_max , col2_max ) * (1 + y_inc ))
109
123
)
110
-
124
+ row_names = None
111
125
offset = np .linspace (- 0.08 * len (df ), 0.08 * len (df ), len (df )) if len (df ) > 1 else [0 ]
112
126
for i , (nrow , data ) in enumerate (zip (nrows , df )):
113
127
data ["pct" ] = data [col ] / nrow * 100
114
128
data .index = [str (val ) for val in data .index ]
129
+ if row_names is None :
130
+ row_names = data .index
115
131
data ["orig" ] = orig [i ]
116
132
117
133
fig .vbar (
@@ -126,7 +142,6 @@ def bar_viz(
126
142
tweak_figure (fig , "bar" , show_yticks )
127
143
128
144
fig .yaxis .axis_label = "Count"
129
-
130
145
x_axis_label = ""
131
146
if ttl_grps > len (df [baseline ]):
132
147
x_axis_label += f"Top { len (df [baseline ])} of { ttl_grps } { col } "
@@ -142,6 +157,21 @@ def bar_viz(
142
157
143
158
if show_yticks and yscale == "linear" :
144
159
_format_axis (fig , 0 , df [baseline ].max (), "y" )
160
+
161
+ df1 , df2 = df_list [0 ], df_list [1 ]
162
+ if target != col and target and col in df1 .columns and col in df2 .columns :
163
+ col1 , col2 = df_list [0 ][col ], df_list [1 ][col ]
164
+ row_avgs_1 = []
165
+ row_avgs_2 = []
166
+ for names in row_names :
167
+ row_avgs_1 .append (df_list [0 ][target ][col1 == names ].mean ())
168
+ row_avgs_2 .append (df_list [1 ][target ][col2 == names ].mean ())
169
+
170
+ row_avgs_1 = [0 if math .isnan (x ) else x for x in row_avgs_1 ]
171
+ row_avgs_2 = [0 if math .isnan (x ) else x for x in row_avgs_2 ]
172
+ fig .extra_y_ranges = {"Averages" : Range1d (start = min (row_avgs_1 + row_avgs_2 ) * (1 - y_inc ), end = max (row_avgs_1 + row_avgs_2 ) * (1 + y_inc ))}
173
+ fig .multi_line ([row_names , row_names ], [row_avgs_1 , row_avgs_2 ], color = ['navy' , 'firebrick' ], y_range_name = "Averages" , line_width = 4 )
174
+ fig .add_layout (LinearAxis (y_range_name = "Averages" ), 'right' )
145
175
return fig
146
176
147
177
@@ -155,28 +185,56 @@ def hist_viz(
155
185
show_yticks : bool ,
156
186
df_labels : List [str ],
157
187
orig : Optional [List [str ]] = None ,
188
+ target : Optional [str ] = None ,
189
+ df_list : Optional [List [pd .DataFrame ]] = None
158
190
) -> Figure :
159
191
"""
160
192
Render a histogram
161
193
"""
162
194
# pylint: disable=too-many-arguments,too-many-locals
163
-
164
195
tooltips = [
165
196
("Bin" , "@intvl" ),
166
197
("Frequency" , "@freq" ),
167
198
("Percent" , "@pct{0.2f}%" ),
168
199
("Source" , "@orig" ),
169
200
]
201
+ df1 , df2 = df_list [0 ], df_list [1 ]
202
+ y_inc = 0.05
203
+ tooltips = [
204
+ ("Bin" , "@intvl" ),
205
+ ("Frequency" , "@freq" ),
206
+ ("Percent" , "@pct{0.2f}%" ),
207
+ ("Source" , "@orig" ),
208
+ ]
209
+ fig = None
210
+
211
+ y_start , y_end = None , None
212
+ counts_list = []
213
+ if target and target != col and col in df1 .columns and col in df2 .columns :
214
+ for hst in hist :
215
+ counts , bins = hst
216
+ counts_list .append (counts )
217
+
218
+ counts_min_1 = min (counts_list [0 ])
219
+ counts_min_2 = min (counts_list [1 ])
220
+
221
+ counts_max_1 = max (counts_list [0 ])
222
+ counts_max_2 = max (counts_list [1 ])
223
+
224
+ y_start , y_end = min (counts_min_1 , counts_min_2 ), max (counts_max_1 , counts_max_2 )
225
+
226
+
170
227
fig = Figure (
171
228
plot_height = plot_height ,
172
229
plot_width = plot_width ,
173
230
title = col ,
174
231
toolbar_location = None ,
175
- y_axis_type = yscale ,
232
+ y_axis_type = yscale
176
233
)
177
-
234
+ bins_list = []
178
235
for i , hst in enumerate (hist ):
179
236
counts , bins = hst
237
+ bins_list .append (bins )
180
238
if sum (counts ) == 0 :
181
239
fig .rect (x = 0 , y = 0 , width = 0 , height = 0 )
182
240
continue
@@ -192,16 +250,34 @@ def hist_viz(
192
250
}
193
251
)
194
252
bottom = 0 if yscale == "linear" or df .empty else counts .min () / 2
195
- fig .quad (
196
- source = df ,
197
- left = "left" ,
198
- right = "right" ,
199
- bottom = bottom ,
200
- alpha = 0.5 ,
201
- top = "freq" ,
202
- fill_color = CATEGORY10 [i ],
203
- line_color = CATEGORY10 [i ],
204
- )
253
+ if y_start is not None and y_end is not None :
254
+ # fig.y_range = (y_start * (1 - y_inc), y_end * (1 + y_inc))
255
+ fig .extra_y_ranges = {"Counts" : Range1d (start = y_start * (1 - y_inc ), end = y_end * (1 + y_inc ))}
256
+ fig .quad (
257
+ source = df ,
258
+ left = "left" ,
259
+ right = "right" ,
260
+ bottom = bottom ,
261
+ alpha = 0.5 ,
262
+ top = "freq" ,
263
+ fill_color = CATEGORY10 [i ],
264
+ line_color = CATEGORY10 [i ],
265
+ y_range_name = "Counts"
266
+ )
267
+ else :
268
+ fig .quad (
269
+ source = df ,
270
+ left = "left" ,
271
+ right = "right" ,
272
+ bottom = bottom ,
273
+ alpha = 0.5 ,
274
+ top = "freq" ,
275
+ fill_color = CATEGORY10 [i ],
276
+ line_color = CATEGORY10 [i ]
277
+ )
278
+ # if col == 'LotFrontage':
279
+ # breakpoint()
280
+
205
281
hover = HoverTool (tooltips = tooltips , attachment = "vertical" , mode = "vline" )
206
282
fig .add_tools (hover )
207
283
@@ -224,6 +300,34 @@ def hist_viz(
224
300
fig .xaxis .axis_label = x_axis_label
225
301
fig .xaxis .axis_label_standoff = 0
226
302
303
+ if target and target != col and col in df1 .columns and col in df2 .columns :
304
+ col1 , col2 = df1 [col ], df2 [col ]
305
+ source1 , source2 = col1 , col2
306
+ col1 = col1 [~ np .isnan (col1 )]
307
+ col2 = col2 [~ np .isnan (col2 )]
308
+ num_bins1 = len (bins_list [0 ]) - 1
309
+ num_bins2 = len (bins_list [1 ]) - 1
310
+ bins_1 , bins_2 = bins_list [0 ], bins_list [1 ]
311
+
312
+ df1_source_bins_series = pd .cut (source1 , bins = bins_1 , labels = False )
313
+ df1_bin_averages = [None ] * num_bins1
314
+
315
+ df2_source_bins_series = pd .cut (source2 , bins = bins_2 , labels = False )
316
+ df2_bin_averages = [None ] * num_bins2
317
+
318
+ for b in range (num_bins1 ):
319
+ df1_bin_averages [b ] = df1 [target ][df1_source_bins_series == b ].mean ()
320
+ for b in range (num_bins2 ):
321
+ df2_bin_averages [b ] = df2 [target ][df2_source_bins_series == b ].mean ()
322
+
323
+ df1_bin_averages = [0 if math .isnan (x ) else x for x in df1_bin_averages ]
324
+ df2_bin_averages = [0 if math .isnan (x ) else x for x in df2_bin_averages ]
325
+ max_range = max (df1_bin_averages + df2_bin_averages )
326
+ min_range = min (df1_bin_averages + df2_bin_averages )
327
+
328
+ fig .extra_y_ranges ['Averages' ] = Range1d (start = min_range * (1 - y_inc ), end = max_range * (1 + y_inc ))
329
+ fig .multi_line ([bins_1 , bins_2 ], [df1_bin_averages , df2_bin_averages ], color = ['navy' , 'firebrick' ], y_range_name = "Averages" , line_width = 4 )
330
+ fig .add_layout (LinearAxis (y_range_name = "Averages" , axis_label = 'Bin Averages' ), 'right' )
227
331
return fig
228
332
229
333
@@ -610,6 +714,9 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
610
714
nrows = itmdt ["stats" ]["nrows" ]
611
715
titles : List [str ] = []
612
716
717
+ df_list = itmdt .df_list
718
+ target = itmdt .target
719
+
613
720
for col , dtp , data , orig in itmdt ["data" ]:
614
721
fig = None
615
722
if is_dtype (dtp , Nominal ()):
@@ -626,6 +733,8 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
626
733
orig ,
627
734
df_labels ,
628
735
baseline if len (df ) > 1 else 0 ,
736
+ target ,
737
+ df_list
629
738
)
630
739
elif is_dtype (dtp , Continuous ()):
631
740
if cfg .diff .density :
@@ -643,6 +752,8 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
643
752
False ,
644
753
df_labels ,
645
754
orig ,
755
+ target ,
756
+ df_list
646
757
)
647
758
elif is_dtype (dtp , DateTime ()):
648
759
df , timeunit = data
@@ -760,7 +871,6 @@ def render_diff(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
760
871
cfg
761
872
Config instance
762
873
"""
763
-
764
874
if itmdt .visual_type == "comparison_grid" :
765
875
visual_elem = render_comparison_grid (itmdt , cfg )
766
876
if itmdt .visual_type == "comparison_continuous" :
0 commit comments