Skip to content

Commit

Permalink
leverage Dash/Jupyterlab integration directly in viz (#403)
Browse files Browse the repository at this point in the history
* update minimum `dash` version
* leverage Dash-Jupyterlab support
* cleanup
* add tests
* formatting & linting
* add show back
* keep and deprecate in_background
* use `run` over deprecated `run_server`
* format dostring
* raise DeprecationWarning
* add type hints
* fix type hints
  • Loading branch information
ryanSoley authored Jan 31, 2024
1 parent 1d405c1 commit e60e2df
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 52 deletions.
2 changes: 1 addition & 1 deletion docs/docs-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:

# install rubicon-ml dependencies
# so local pip install doesn't
- dash<=2.14.2,>=2.0.0
- dash<=2.14.2,>=2.11.0
- dash-bootstrap-components<=1.5.0,>=1.0.0
- fsspec<=2023.12.2,>=2021.4.0
- intake[dataframe]<=0.7.0,>=0.5.2
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
- s3fs<=2023.12.2,>=0.4

# for viz extras
- dash<=2.14.2,>=2.0.0
- dash<=2.14.2,>=2.11.0
- dash-bootstrap-components<=1.5.0,>=1.0.0

# for testing
Expand Down
122 changes: 75 additions & 47 deletions rubicon_ml/viz/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import threading
import time
import warnings
from typing import Dict, Literal, Optional, Union

import dash_bootstrap_components as dbc
from dash import Dash, html
Expand All @@ -19,7 +18,7 @@ class VizBase:

def __init__(
self,
dash_title="base",
dash_title: str = "base",
):
self.dash_title = f"rubicon-ml: {dash_title}"

Expand Down Expand Up @@ -54,20 +53,34 @@ def load_experiment_data(self):
"extensions of `VizBase` must implement `load_experiment_data(self)`"
)

def register_callbacks(self, link_experiment_table=False):
def register_callbacks(self, link_experiment_table: bool = False):
raise NotImplementedError(
"extensions of `VizBase` must implement `register_callbacks(self)`"
)

def serve(self, in_background=False, dash_kwargs={}, run_server_kwargs={}):
def serve(
self,
in_background: bool = False,
jupyter_mode: Literal["external", "inline", "jupyterlab", "tab"] = "external",
dash_kwargs: Dict = {},
run_server_kwargs: Dict = {},
):
"""Serve the Dash app on the next available port to render the visualization.
Parameters
----------
in_background : bool, optional
True to run the Dash app on a thread and return execution to the
interpreter. False to run the Dash app inline and block execution.
Defaults to False.
DEPRECATED. Background processing is now handled by `jupyter_mode`.
jupyter_mode : "external", "inline", "jupyterlab", or "tab", optional
How to render the dashboard when running from Jupyterlab.
* "external" to serve the dashboard at an external link.
* "inline" to render the dashboard in the current notebook's output
cell.
* "jupyterlab" to render the dashboard in a new window within the
current Jupyterlab session.
* "tab" to serve the dashboard at an external link and open a new
browser tab to said link.
Defaults to "external".
dash_kwargs : dict, optional
Keyword arguments to be passed along to the newly instantiated
Dash object. Available options can be found at
Expand All @@ -79,6 +92,20 @@ def serve(self, in_background=False, dash_kwargs={}, run_server_kwargs={}):
the 'port' argument can be provided here to serve the app on a
specific port.
"""
if in_background:
warnings.warn(
"The `in_background` argument is deprecated and will have no effect, "
"Background processing is now handled by `jupyter_mode`.",
DeprecationWarning,
)

JUPYTER_MODES = ["external", "inline", "jupyterlab", "tab"]
if jupyter_mode not in JUPYTER_MODES:
raise ValueError(
f"Invalid `jupyter_mode` '{jupyter_mode}'. Must be one of "
f"{', '.join(JUPYTER_MODES)}"
)

if self.experiments is None:
raise RuntimeError(
f"`{self.__class__}.experiments` can not be None when `serve` is called"
Expand All @@ -103,39 +130,29 @@ def serve(self, in_background=False, dash_kwargs={}, run_server_kwargs={}):
}
default_run_server_kwargs.update(run_server_kwargs)

_next_available_port = default_run_server_kwargs["port"] + 1

if in_background:
running_server_thread = threading.Thread(
name="run_server",
target=self.app.run_server,
kwargs=default_run_server_kwargs,
)
running_server_thread.daemon = True
running_server_thread.start()
if jupyter_mode != "inline":
default_run_server_kwargs["jupyter_mode"] = jupyter_mode

port = default_run_server_kwargs.get("port")
if "proxy" in run_server_kwargs:
host = default_run_server_kwargs.get("proxy").split("::")[-1]
else:
host = f"http://localhost:{port}"
_next_available_port = default_run_server_kwargs["port"] + 1

time.sleep(0.1) # wait for thread to see if requested port is available
if not running_server_thread.is_alive():
raise RuntimeError(f"port {port} may already be in use")
self.app.run(**default_run_server_kwargs)

return host
else:
self.app.run_server(**default_run_server_kwargs)
def show(
self,
i_frame_kwargs: Dict = {},
dash_kwargs: Dict = {},
run_server_kwargs: Dict = {},
height: Optional[Union[int, str]] = None,
width: Optional[Union[int, str]] = None,
):
"""Serve the Dash app on the next available port to render the visualization.
def show(self, i_frame_kwargs={}, dash_kwargs={}, run_server_kwargs={}):
"""Show the Dash app inline in a Jupyter notebook.
Additionally, renders the visualization inline in the current Jupyter notebook.
Parameters
----------
i_frame_kwargs : dict, optional
Keyword arguments to be passed along to the newly instantiated
IFrame object. Available options include 'height' and 'width'.
i_frame_kwargs: dict, optional
DEPRECATED. Use `height` and `width` instead.
dash_kwargs : dict, optional
Keyword arguments to be passed along to the newly instantiated
Dash object. Available options can be found at
Expand All @@ -146,18 +163,29 @@ def show(self, i_frame_kwargs={}, dash_kwargs={}, run_server_kwargs={}):
https://dash.plotly.com/reference#app.run_server. Most commonly,
the 'port' argument can be provided here to serve the app on a
specific port.
height : int, str or None, optional
The height of the inline visualizaiton. Integers represent number
of pixels, strings represent a percentage of the window and must
end with '%'.
width : int, str or None, optional
The width of the inline visualizaiton. Integers represent number
of pixels, strings represent a percentage of the window and must
end with '%'.
"""
from IPython.display import IFrame
if i_frame_kwargs:
warnings.warn(
"The `i_frame_kwargs` argument is deprecated and will have no effect, "
"use `height` and `width` instead.",
DeprecationWarning,
)

host = self.serve(
in_background=True, dash_kwargs=dash_kwargs, run_server_kwargs=run_server_kwargs
)
proxied_host = os.path.join(host, self.app.config["requests_pathname_prefix"].lstrip("/"))
if height is not None:
run_server_kwargs["jupyter_height"] = height
if width is not None:
run_server_kwargs["jupyter_width"] = width

default_i_frame_kwargs = {
"height": "600px",
"width": "100%",
}
default_i_frame_kwargs.update(i_frame_kwargs)

return IFrame(proxied_host, **default_i_frame_kwargs)
self.serve(
jupyter_mode="inline",
dash_kwargs=dash_kwargs,
run_server_kwargs=run_server_kwargs,
)
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ prefect =
s3 =
s3fs<=2023.12.2,>=0.4
ui =
dash<=2.14.2,>=2.0.0
dash<=2.14.2,>=2.11.0
dash-bootstrap-components<=1.5.0,>=1.0.0
viz =
dash<=2.14.2,>=2.0.0
dash<=2.14.2,>=2.11.0
dash-bootstrap-components<=1.5.0,>=1.0.0
all =
dash<=2.14.2,>=2.0.0
dash<=2.14.2,>=2.11.0
dash-bootstrap-components<=1.5.0,>=1.0.0
prefect<=1.2.4,>=0.12.0
s3fs<=2023.12.2,>=0.4
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/viz/test_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from dash import Dash, html

from rubicon_ml.viz.base import VizBase
Expand All @@ -19,3 +20,24 @@ def test_base_build_layout():

assert len(layout) == 8
assert layout.children.children[-1].children.id == "test-layout"


def test_base_serve_jupyter_mode_errors():
base = VizBaseTest()
base.build_layout()

with pytest.raises(ValueError) as error:
base.serve(jupyter_mode="invalid")

assert "Invalid `jupyter_mode`" in str(error)


def test_base_serve_missing_experiment_errors():
base = VizBaseTest()
base.build_layout()
base.experiments = None

with pytest.raises(RuntimeError) as error:
base.serve()

assert "experiments` can not be None" in str(error)

0 comments on commit e60e2df

Please sign in to comment.