Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop values without metric readings in ParallelCoordinatesPlot #2898

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions ax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ax.utils.common.base import Base
from ax.utils.common.logger import get_logger
from ax.utils.common.result import Err, ExceptionE, Ok, Result
from IPython.display import display
from IPython.display import display, Markdown

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -79,15 +79,23 @@ def _ipython_display_(self) -> None:

By default, this method displays the raw data in a pandas DataFrame.
"""
display(Markdown(f"## {self.title}\n\n### {self.subtitle}"))
display(self.df)


def display_cards(cards: Iterable[AnalysisCard]) -> None:
def display_cards(
cards: Iterable[AnalysisCard], minimum_level: int = AnalysisCardLevel.LOW
) -> None:
"""
Display a collection of AnalysisCards in IPython environments (ex. Jupyter).

Args:
cards: Collection of AnalysisCards to display.
minimum_level: Minimum level of cards to display.
"""
for card in cards:
display(card)
for card in sorted(cards, key=lambda x: x.level, reverse=True):
if card.level >= minimum_level:
display(card)


class Analysis(Protocol):
Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/markdown/markdown_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _ipython_display_(self) -> None:
IPython display hook. This is called when the AnalysisCard is printed in an
IPython environment (ex. Jupyter). Here we want to render the Markdown.
"""
display(Markdown(self.blob))
display(Markdown(f"## {self.title}\n\n### {self.subtitle}\n\n{self.blob}"))


class MarkdownAnalysis(Analysis):
Expand Down
7 changes: 2 additions & 5 deletions ax/analysis/plotly/parallel_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _prepare_data(experiment: Experiment, metric: str) -> pd.DataFrame:
for arm in trial.arms
]

return pd.DataFrame.from_records(records)
return pd.DataFrame.from_records(records).dropna()


def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure:
Expand All @@ -96,10 +96,7 @@ def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure:

return go.Figure(
go.Parcoords(
line={
"color": df[metric_name],
"showscale": True,
},
line={"color": df[metric_name], "showscale": True},
dimensions=[
*parameter_dimensions,
{
Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/plotly/plotly_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ax.analysis.analysis import Analysis, AnalysisCard
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from IPython.display import display
from IPython.display import display, Markdown
from plotly import graph_objects as go, io as pio


Expand All @@ -25,6 +25,7 @@ def _ipython_display_(self) -> None:
IPython display hook. This is called when the AnalysisCard is printed in an
IPython environment (ex. Jupyter). Here we want to display the Plotly figure.
"""
display(Markdown(f"## {self.title}\n\n### {self.subtitle}"))
display(self.get_figure())


Expand Down