From 515218df895e2788c9451af6764c1235817fe0f5 Mon Sep 17 00:00:00 2001 From: Auguste Baum Date: Wed, 16 Oct 2024 17:59:11 +0200 Subject: [PATCH] add average score horizontal rule to cross-val chart --- skore/src/skore/item/cross_validation_item.py | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/skore/src/skore/item/cross_validation_item.py b/skore/src/skore/item/cross_validation_item.py index 2976c570..68e0f700 100644 --- a/skore/src/skore/item/cross_validation_item.py +++ b/skore/src/skore/item/cross_validation_item.py @@ -54,23 +54,35 @@ def plot_cross_validation(cv_results: dict) -> altair.Chart: fields=["metric"], bind=input_dropdown, value="test_score" ) - return ( - altair.Chart(df, title="Cross-validation scores per split") - .mark_bar() + average_score_rule = ( + altair.Chart(df) + .mark_rule(strokeWidth=2) .encode( - altair.X("split:N").axis( - title="Split number", - labelAngle=0, - ), - altair.Y("score:Q").axis( - title="Score", - titleAngle=0, - titleAlign="left", - titleX=0, - titleY=-5, - labelLimit=300, - ), - tooltip=["metric:N", "split:N", "score:Q"], + y="mean(score):Q", + tooltip=["mean(score):Q"], + ) + ) + + return ( + ( + altair.Chart(df, title="Cross-validation scores per split") + .mark_bar() + .encode( + altair.X("split:N").axis( + title="Split number", + labelAngle=0, + ), + altair.Y("score:Q").axis( + title="Score", + titleAngle=0, + titleAlign="left", + titleX=0, + titleY=-5, + labelLimit=300, + ), + tooltip=["metric:N", "split:N", "score:Q"], + ) + + average_score_rule ) .add_params(selection) .transform_filter(selection)